{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Omniglot Character Set Classification Using a Prototypical Network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Now we will see how to use prototypical networks to perform a classification task. We use the Omniglot dataset, which comprises 1,623 handwritten characters from 50 different alphabets; each character has 20 examples, each written by a different person. Since we want our network to learn from little data, we also train it on little data at a time, in episodes. In each episode, we sample five examples from each class and use them as our support set. We learn the embeddings of the support set using a sequence of four convolutional blocks as our encoder and build the class prototypes. Similarly, we sample five examples from each class for the query set, learn the query set embeddings, and predict the class of each query point by comparing the Euclidean distance between its embedding and the class prototypes. Let us better understand this by going through it step by step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we import all the required libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "from PIL import Image\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding the dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will explore what we have in our data. As we know, we have different characters from different alphabets, and each character has twenty variants written by different people. Let's plot and check some of them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us plot one character from the Japanese (katakana) alphabet:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAABpAQAAAAAR+TCXAAAAqUlEQVR4nO3OMQ6CQBCF4X9wEyi3\ntOQYlhQeywKiByOehNJSuzUhjIXE8DDGimjBVPvlzeyMOZNqMqRWvvP0+zM8wqBpdTOz4kWPQA8E\nuJt8lb4uyt19wgQegIDhnCfNx+drGPcexiTKz70uSqXwqmlXCdta2GhzJvQgHAphj5AovJTCrhK2\ns9lauNF0r9zNZv+ebmYWNU2fmkvlFqAAAgC5S7rczSuX5gMYvSMsiQG6eAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=1 size=105x105 at 0xCF85198>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image.open('data/images/Japanese_(katakana)/character13/0608_01.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is the same character, written by a different person:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAABpAQAAAAAR+TCXAAAAy0lEQVR4nO2SsQ3CQBAE594ICAhc\nAIFLoAKgNJsKKIFSEJVABwghYZDxEfhtWCMRIgIuG+3e/t3/m/NSRUDqj294+I0x9ooG3NR8UTwu\nBTeiei7qiTziGLin1por4BxtAWoo0tgzALf3mUcdbgErBy3OgKF36gQW5TMq8d2ceixRVaob6UFk\ngqWq66VgLqpHblUTtQ6CYarJmaBte1P9ALpi1jPnn3rTnnkFVwCad3SL8QGaz9L9nFhPHGqUVe1F\nN1GJv6qfN/rjt/EBTzonExbNmAgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=1 size=105x105 at 0xD0222B0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image.open('data/images/Japanese_(katakana)/character13/0608_13.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us look at a character from the Sanskrit alphabet, again in two different variations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAABpAQAAAAAR+TCXAAAAvUlEQVR4nO3TMQrCMBgF4PfHgB0c\nOjoI9ig5ikfwBrXgoRz1Bh6hR+jg0EJJHAT7XtGCi1gw25eXkP8niSXQqBxk/PkRKzNzmqa3i1/Q\nBpYpXeI3ap4lu0zTLjDbyb3NMOkA1LVnnrcts9K9TshfxwFYTZ6rVVmvaVSWljMDpKoCPXN00Gag\nB7BIR118UEr71NMjzYRxlOZKTq+4MU9oCmKPmtOIEIjJTIt0EHrmHmtmht2THljSJf3QA54b7wDY\nJdokvv1PAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=1 size=105x105 at 0xD0225C0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image.open('data/images/Sanskrit/character13/0863_09.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGkAAABpAQAAAAAR+TCXAAAA+ElEQVR4nO3SsW2EMBTG8b+NFV2B\nBCMwQjaAkTJBRDa5TUI2yAiko3SkFBDBvRTmzHsprk2Ko0D6yebzZxsnqOfFY547/5Bvrk4MIB74\ndFVMo5c0MO+Tv+y375YdtBc4paiNKiJV3JMLsZMhp+dW9Y2SSzCczKg0NIrf0ClO0CuecxUPbH2p\no1ai5lyYGrE05NFwHAyd3YKwXbcQAFye4IEGRJWsgeHgCRgPBtwYj6hCWGbVCh4wdKuhOQ1bknO+\nFQ/wZE5S5lsbZDWdebYLSb94xW6YgmIzNjqq/ki/0c4yH7VPr0Iv5LheS0ocXg3bVif/6nzn/+IP\nc/48QDbAE1QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=1 size=105x105 at 0xD0226A0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Image.open('data/images/Sanskrit/character13/0863_13.png')"
   ]
  },
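  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How can we convert this image into an array? Each class is named by its alphabet, character, and rotation (for example, Sanskrit/character13/rot000). We rotate the image by the given angle, resize it to 28 x 28, and then use the np.array function to convert it into an array:"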
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "image_name = 'data/images/Sanskrit/character13/0863_13.png'\n",
    "alphabet, character, rotation = 'Sanskrit/character13/rot000'.split('/')\n",
    "rotation = float(rotation[3:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "         0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,\n",
       "         0.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "         0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  1.,  1.,  1.,\n",
       "         1.,  1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  1.,  1.,\n",
       "         1.,  0.,  0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  0.,  0.,  0.,\n",
       "         0.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.],\n",
       "       [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,\n",
       "         1.,  1.]], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(Image.open(image_name).rotate(rotation).resize((28, 28)), np.float32,copy=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have understood what is in our dataset, let us load it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "root_dir = 'data/'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split details are in the splits/train.txt file under root_dir. Each line of that file names one class by its alphabet name, character number, and rotation (for example, Sanskrit/character13/rot000). The images themselves are read from the data/ directory under root_dir."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_split_path = os.path.join(root_dir, 'splits', 'train.txt')\n",
    "\n",
    "with open(train_split_path, 'r') as train_split:\n",
    "    train_classes = [line.rstrip() for line in train_split.readlines()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#number of classes\n",
    "no_of_classes = len(train_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we set the number of examples to 20, as we have 20 examples per class in our dataset, and we set the image width and height to 28 x 28 with a single channel:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#number of examples\n",
    "num_examples = 20\n",
    "\n",
    "#image width\n",
    "img_width = 28\n",
    "\n",
    "#image height\n",
    "img_height = 28\n",
    "channels = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize our training dataset as an array of zeros with the shape (number of classes, number of examples, image height, image width):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train_dataset = np.zeros([no_of_classes, num_examples, img_height, img_width], dtype=np.float32)"
   ]
  },
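  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we read all the images, convert each one to a NumPy array, and store it in our train_dataset array, indexed by its class label and its position within the class, that is, train_dataset[label, index] = values. We also invert the pixel values (1 - value), so that the strokes are 1 and the background is 0:"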
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for label, name in enumerate(train_classes):\n",
    "    alphabet, character, rotation = name.split('/')\n",
    "    rotation = float(rotation[3:])\n",
    "    img_dir = os.path.join(root_dir, 'data', alphabet, character)\n",
    "    img_files = sorted(glob.glob(os.path.join(img_dir, '*.png')))\n",
    "  \n",
    "    \n",
    "    for index, img_file in enumerate(img_files):\n",
    "        values = 1. - np.asarray(Image.open(img_file).rotate(rotation).resize((img_width, img_height)), dtype=np.float32)\n",
    "        train_dataset[label, index] = values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4112, 20, 28, 28)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have loaded our training data, we need to create embeddings for it. Since our inputs are images, we generate the embeddings using convolution operations. So, we define a convolutional block consisting of a convolution with a given number of filters (we will use 64), batch normalization, and ReLU as the activation function, followed by a max pooling operation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def convolution_block(inputs, out_channels, name='conv'):\n",
    "\n",
    "    # a named variable scope lets a later call with reuse=True find these weights\n",
    "    with tf.variable_scope(name):\n",
    "        conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding='SAME')\n",
    "        conv = tf.contrib.layers.batch_norm(conv, updates_collections=None, decay=0.99, scale=True, center=True)\n",
    "        conv = tf.nn.relu(conv)\n",
    "        conv = tf.contrib.layers.max_pool2d(conv, 2)\n",
    "    \n",
    "        return conv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we define our embedding function, which builds the embedding using four convolutional blocks. We create the blocks inside a variable scope so that a second call with reuse=True shares the weights created by the first call:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_embeddings(support_set, h_dim, z_dim, reuse=False):\n",
    "\n",
    "    # reuse=True shares the encoder weights created by an earlier call\n",
    "    with tf.variable_scope('encoder', reuse=reuse):\n",
    "        net = convolution_block(support_set, h_dim, name='conv_1')\n",
    "        net = convolution_block(net, h_dim, name='conv_2')\n",
    "        net = convolution_block(net, h_dim, name='conv_3')\n",
    "        net = convolution_block(net, z_dim, name='conv_4')\n",
    "        net = tf.contrib.layers.flatten(net)\n",
    "        \n",
    "        return net"
   ]
  },
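  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how large the flattened embedding is, it can help to trace the spatial size of a 28 x 28 image through the four blocks. The following is a minimal sketch, assuming that the convolution keeps the spatial size ('SAME' padding with stride 1) and that the 2 x 2 max pooling uses stride 2 without padding, so that each block halves the size and rounds down. The helper trace_spatial_sizes is a hypothetical name used only for illustration:\n",
    "\n",
    "```python\n",
    "def trace_spatial_sizes(size, num_blocks):\n",
    "    # each block: the convolution keeps the size, the 2 x 2 pooling halves it (rounding down)\n",
    "    sizes = [size]\n",
    "    for _ in range(num_blocks):\n",
    "        size = size // 2\n",
    "        sizes.append(size)\n",
    "    return sizes\n",
    "\n",
    "sizes = trace_spatial_sizes(28, 4)\n",
    "print(sizes)  # [28, 14, 7, 3, 1]\n",
    "\n",
    "# flattened embedding size = height * width * number of filters in the last block\n",
    "embedding_size = sizes[-1] * sizes[-1] * 64\n",
    "print(embedding_size)  # 64\n",
    "```\n",
    "\n",
    "Under these assumptions, each image is mapped to a 64-dimensional embedding vector."
   ]
  },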
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember, we don't use our whole dataset for training at once. Since we are doing few-shot learning, we sample a few data points from each class as a support set and train the network using the support set in an episodic fashion. \n",
    "\n",
    "\n",
    "Now we define some important variables. We consider a 60-way 5-shot learning scenario:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#number of classes\n",
    "num_way = 60  \n",
    "\n",
    "#number of examples per class for support set\n",
    "num_shot = 5  \n",
    "\n",
    "#number of query points\n",
    "num_query = 5 \n",
    "\n",
    "#number of examples\n",
    "num_examples = 20\n",
    "\n",
    "h_dim = 64\n",
    "\n",
    "z_dim = 64"
   ]
  },
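  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sampling one episode with these settings can be sketched in plain NumPy as follows. This is only an illustration on random data: toy_dataset stands in for train_dataset, and all the toy_ names are hypothetical, not part of the notebook:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# toy stand-in for train_dataset: 100 classes, 20 examples each, 28 x 28 images\n",
    "toy_dataset = rng.rand(100, 20, 28, 28).astype(np.float32)\n",
    "toy_way, toy_shot, toy_query = 60, 5, 5\n",
    "\n",
    "# pick the classes for this episode, without replacement\n",
    "episode_classes = rng.permutation(toy_dataset.shape[0])[:toy_way]\n",
    "\n",
    "toy_support = np.zeros([toy_way, toy_shot, 28, 28], dtype=np.float32)\n",
    "toy_query_set = np.zeros([toy_way, toy_query, 28, 28], dtype=np.float32)\n",
    "\n",
    "for i, cls in enumerate(episode_classes):\n",
    "    # draw disjoint support and query examples for this class\n",
    "    picked = rng.permutation(toy_dataset.shape[1])[:toy_shot + toy_query]\n",
    "    toy_support[i] = toy_dataset[cls, picked[:toy_shot]]\n",
    "    toy_query_set[i] = toy_dataset[cls, picked[toy_shot:]]\n",
    "\n",
    "print(toy_support.shape, toy_query_set.shape)\n",
    "```\n",
    "\n",
    "Each episode therefore contains 60 x 5 = 300 support images and 60 x 5 = 300 query images, and no image appears in both sets."
   ]
  },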
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we initialize placeholders for our support set and query set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "support_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])\n",
    "query_set = tf.placeholder(tf.float32, [None, None, img_height, img_width, channels])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And we store the shape of our support set and query set in support_set_shape and query_set_shape respectively:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "support_set_shape = tf.shape(support_set)\n",
    "query_set_shape = tf.shape(query_set)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "From these shapes, we get the number of classes and the number of data points per class in the support set, and the number of data points per class in the query set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_classes, num_support_points = support_set_shape[0], support_set_shape[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_query_points = query_set_shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we define the placeholder for our label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y = tf.placeholder(tf.int64, [None, None])\n",
    "\n",
    "#convert the label to one hot\n",
    "y_one_hot = tf.one_hot(y, depth=num_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we generate the embeddings for our support set using our embedding function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "support_set_embeddings = get_embeddings(tf.reshape(support_set, [num_classes * num_support_points, img_height, img_width, channels]), h_dim, z_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We compute the prototype of each class which is the mean vector of the support set embeddings of the class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "embedding_dimension = tf.shape(support_set_embeddings)[-1]\n",
    "\n",
    "class_prototype = tf.reduce_mean(tf.reshape(support_set_embeddings, [num_classes, num_support_points, embedding_dimension]), axis=1)"
   ]
  },
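  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same reshape-and-average step can be sketched in NumPy on a tiny hand-made example. The toy_ names and values below are made up for illustration; the rows are laid out class by class, just like the flattened support set:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# toy embeddings: 3 classes, 2 support points each, 4-dimensional embeddings\n",
    "toy_embeddings = np.array([\n",
    "    [1., 1., 1., 1.], [3., 3., 3., 3.],   # class 0\n",
    "    [0., 2., 0., 2.], [0., 4., 0., 4.],   # class 1\n",
    "    [5., 5., 5., 5.], [5., 5., 5., 5.],   # class 2\n",
    "])\n",
    "\n",
    "# group the rows by class, then average over the support points of each class\n",
    "toy_prototypes = toy_embeddings.reshape(3, 2, 4).mean(axis=1)\n",
    "print(toy_prototypes)\n",
    "```\n",
    "\n",
    "The result has one row per class: [2, 2, 2, 2], [0, 3, 0, 3] and [5, 5, 5, 5]."
   ]
  },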
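  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we use the same embedding function to get the embeddings of the query set, passing reuse=True so that the query set is meant to be embedded with the same weights as the support set:"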
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "query_set_embeddings = get_embeddings(tf.reshape(query_set, [num_classes * num_query_points, img_height, img_width, channels]), h_dim, z_dim, reuse=True)"
   ]
  },
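  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the class prototypes and the query set embeddings, we define a distance function which gives us the distance between each query set embedding and each class prototype. Note that the function averages the squared differences over the embedding dimensions, so it returns the squared Euclidean distance scaled by 1/D, where D is the embedding size:"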
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def euclidean_distance(a, b):\n",
    "\n",
    "    N, D = tf.shape(a)[0], tf.shape(a)[1]\n",
    "    M = tf.shape(b)[0]\n",
    "    a = tf.tile(tf.expand_dims(a, axis=1), (1, M, 1))\n",
    "    b = tf.tile(tf.expand_dims(b, axis=0), (N, 1, 1))\n",
    "    return tf.reduce_mean(tf.square(a - b), axis=2)\n"
   ]
  },
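  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, the same pairwise computation can be written in NumPy, where broadcasting takes the place of the explicit tiling. This is a sketch on made-up values; pairwise_mean_squared_difference and the toy_ names are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pairwise_mean_squared_difference(a, b):\n",
    "    # a: (N, D) query embeddings, b: (M, D) prototypes -> (N, M) distances\n",
    "    diff = a[:, np.newaxis, :] - b[np.newaxis, :, :]\n",
    "    return np.mean(np.square(diff), axis=2)\n",
    "\n",
    "toy_queries = np.array([[0., 0.], [2., 2.]])\n",
    "toy_protos = np.array([[0., 0.], [2., 0.], [2., 2.]])\n",
    "\n",
    "toy_distance = pairwise_mean_squared_difference(toy_queries, toy_protos)\n",
    "print(toy_distance)\n",
    "```\n",
    "\n",
    "Row i of the result holds the distances from query i to every prototype, here [0, 2, 4] and [4, 2, 0]."
   ]
  },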
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate the distance between the class prototype and query set embeddings:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "distance = euclidean_distance(query_set_embeddings,class_prototype)"
   ]
  },
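  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we get the log-probability of each class for every query point by applying a log-softmax over the negative distances, so that closer prototypes get higher probability. Note that, despite its name, predicted_probability holds log-probabilities:"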
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "predicted_probability = tf.reshape(tf.nn.log_softmax(-distance), [num_classes, num_query_points, -1])"
   ]
  },
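  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small NumPy illustration of this step, the sketch below applies a log-softmax to the negated distances of two query points. The log_softmax helper is written out by hand here and the toy_ values are made up:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def log_softmax(logits):\n",
    "    # subtract the row maximum first for numerical stability\n",
    "    shifted = logits - logits.max(axis=-1, keepdims=True)\n",
    "    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))\n",
    "\n",
    "# distances from 2 query points to 3 prototypes\n",
    "toy_distance = np.array([[0., 2., 4.], [4., 2., 0.]])\n",
    "toy_log_prob = log_softmax(-toy_distance)\n",
    "\n",
    "print(np.exp(toy_log_prob).sum(axis=-1))  # the probabilities in each row sum to 1\n",
    "print(toy_log_prob.argmax(axis=-1))       # [0 2]\n",
    "```\n",
    "\n",
    "The most probable class for each query point is the one whose prototype is nearest."
   ]
  },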
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We compute the loss as the mean negative log-probability of the true class over all query points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss = -tf.reduce_mean(tf.reshape(tf.reduce_sum(tf.multiply(y_one_hot, predicted_probability), axis=-1), [-1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We calculate the accuracy as the fraction of query points whose most probable class matches the label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "accuracy = tf.reduce_mean(tf.to_float(tf.equal(tf.argmax(predicted_probability, axis=-1), y)))"
   ]
  },
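  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss and accuracy computations can be checked by hand on a tiny example. The following NumPy sketch uses made-up log-probabilities for 2 classes with 2 query points each; the toy_ names are hypothetical:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# toy log-probabilities with shape (classes, query points, class scores)\n",
    "toy_log_prob = np.log(np.array([\n",
    "    [[0.9, 0.1], [0.6, 0.4]],   # queries whose true class is 0\n",
    "    [[0.2, 0.8], [0.7, 0.3]],   # queries whose true class is 1\n",
    "]))\n",
    "toy_labels = np.array([[0, 0], [1, 1]])\n",
    "\n",
    "# the one-hot labels pick out the log-probability of the true class\n",
    "toy_one_hot = np.eye(2)[toy_labels]\n",
    "toy_loss = -np.mean(np.sum(toy_one_hot * toy_log_prob, axis=-1))\n",
    "\n",
    "# a prediction is correct when the most probable class equals the label\n",
    "toy_accuracy = np.mean(toy_log_prob.argmax(axis=-1) == toy_labels)\n",
    "\n",
    "print(toy_loss, toy_accuracy)\n",
    "```\n",
    "\n",
    "Three of the four query points are classified correctly, so the accuracy is 0.75, and the loss is the average of -log(0.9), -log(0.6), -log(0.8) and -log(0.3)."
   ]
  },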
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the Adam optimizer to minimize the loss:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "train = tf.train.AdamOptimizer().minimize(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we start our TensorFlow session and train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "init = tf.global_variables_initializer()\n",
    "sess.run(init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_epochs = 20\n",
    "num_episodes = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "for epoch in range(num_epochs):\n",
    "    \n",
    "    for episode in range(num_episodes):\n",
    "        \n",
    "        # select 60 classes\n",
    "        episodic_classes = np.random.permutation(no_of_classes)[:num_way]\n",
    "        \n",
    "        support = np.zeros([num_way, num_shot, img_height, img_width], dtype=np.float32)\n",
    "        \n",
    "        query = np.zeros([num_way, num_query, img_height, img_width], dtype=np.float32)\n",
    "        \n",
    "        \n",
    "        for index, class_ in enumerate(episodic_classes):\n",
    "            selected = np.random.permutation(num_examples)[:num_shot + num_query]\n",
    "            support[index] = train_dataset[class_, selected[:num_shot]]\n",
    "            \n",
    "            # 5 query points per class\n",
    "            query[index] = train_dataset[class_, selected[num_shot:]]\n",
    "            \n",
    "        support = np.expand_dims(support, axis=-1)\n",
    "        query = np.expand_dims(query, axis=-1)\n",
    "        labels = np.tile(np.arange(num_way)[:, np.newaxis], (1, num_query)).astype(np.uint8)\n",
    "        _, loss_, accuracy_ = sess.run([train, loss, accuracy], feed_dict={support_set: support, query_set: query, y:labels})\n",
    "        \n",
    "        if (episode+1) % 10 == 0:\n",
    "            print('Epoch {} : Episode {} : Loss: {}, Accuracy: {}'.format(epoch+1, episode+1, loss_, accuracy_))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
